{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Functions:\n",
    "def reset_graph(seed=42):\n",
    "    tf.reset_default_graph()\n",
    "    tf.set_random_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    \n",
    "# def squash(vector):\n",
    "#     epsilon = 0.001\n",
    "#     norm_vector = tf.sqrt(tf.reduce_sum(tf.square(vector)))\n",
    "#     norm_vector_epsilon = tf.sqrt(tf.reduce_sum(tf.square(vector))+epsilon)\n",
    "    \n",
    "#     return norm_vector/(norm_vector+1)*(tf.norm(vector)/tf.norm(vector))\n",
    "\n",
    "\n",
    "def squash(vector):\n",
    "    '''Squashing function.\n",
    "    Args:\n",
    "        vector: A 4-D tensor with shape [batch_size, num_caps, vec_len, 1],\n",
    "    Returns:\n",
    "        A 4-D tensor with the same shape as vector but\n",
    "        squashed in 3rd and 4th dimensions.\n",
    "    '''\n",
    "    vec_abs = tf.sqrt(tf.reduce_sum(tf.square(vector)))  # a scalar\n",
    "    scalar_factor = tf.square(vec_abs) / (1 + tf.square(vec_abs))\n",
    "    vec_squashed = scalar_factor * tf.divide(vector, vec_abs)  # element-wise\n",
    "    return(vec_squashed)\n",
    "\n",
    "def routing(u_hat, b_IJ, num_iter):\n",
    "    # Stopping the routing:\n",
    "    u_hat_stopped = tf.stop_gradient(u_hat, name='u_hat_stopped')\n",
    "    print('u_hat shape: ',u_hat_stopped.shape)\n",
    "    # Routing\n",
    "    with tf.name_scope('routing'):\n",
    "        for r_iter in range(num_iter):\n",
    "            c = tf.nn.softmax(b_IJ,axis=2)\n",
    "            #assert c.get_shape().as_list() == [5000,1152,10,1,1]\n",
    "            if r_iter == num_iter-1:\n",
    "                s_j = tf.reduce_sum(tf.multiply(c,u_hat),axis = 1, keepdims = True)\n",
    "                v = squash(s_j)\n",
    "            else:\n",
    "                s_j = tf.reduce_sum(tf.multiply(c,u_hat_stopped),axis = 1,keepdims=True)\n",
    "                v = squash(s_j)\n",
    "                v_tiled = tf.tile(v,[1, 1152,1,1,1])\n",
    "                a = tf.matmul(u_hat_stopped, v_tiled, transpose_a=True)\n",
    "                b_IJ = b_IJ + a\n",
    "#         print('c shape: ',c.shape)\n",
    "#         print('s_j shape: ',s_j.shape)\n",
    "#         print('v shape: ',v.shape)\n",
    "#         print('a shape: ',a.shape)\n",
    "#         print('b_IJ shape: ',b_IJ.shape)\n",
    "    return v,b_IJ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Downloading mnist dataset:\n",
    "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "# Train set\n",
    "X_train = X_train.astype(np.float32).reshape(-1, 28, 28,1) / 255.0\n",
    "y_train = y_train.astype(np.int32)\n",
    "# Test set:\n",
    "X_test = X_test.astype(np.float32).reshape(-1, 28, 28,1) / 255.0\n",
    "y_test = y_test.astype(np.int32)\n",
    "training_size = 2500\n",
    "# Validation set:\n",
    "X_train, X_valid = X_train[:training_size], X_train[training_size:]\n",
    "y_train, y_valid = y_train[:training_size], y_train[training_size:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some constants:\n",
    "batch_size = 500 #None#X_train.shape[0]\n",
    "n_batches = int(X_train.shape[0]/batch_size)\n",
    "n_inputs = 28*28\n",
    "channels = 1\n",
    "#n_output_conv1 = (20,20,256)\n",
    "height,width = 28,28\n",
    "# variables:\n",
    "num_iter = 5\n",
    "n_outputs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize graph:\n",
    "reset_graph()\n",
    "# Placeholders:\n",
    "X = tf.placeholder(shape=(batch_size, height, width, channels), dtype=tf.float32)\n",
    "y = tf.placeholder(tf.int32, shape=(batch_size), name=\"y\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(500, 20, 20, 256)\n"
     ]
    }
   ],
   "source": [
    "# First Conv layer:\n",
    "with tf.name_scope('cnn'):\n",
    "    conv = tf.layers.conv2d(X, filters=256, kernel_size=9, strides=[1,1], padding='VALID')\n",
    "    print(conv.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "caps shape:  (500, 6, 6, 256)\n",
      "u_i shape:  (500, 1152, 8, 1)\n"
     ]
    }
   ],
   "source": [
    "# Problems with the implementation?\n",
    "# First Capsule LAyer:\n",
    "with tf.name_scope('caps'):\n",
    "    caps = tf.layers.conv2d(conv,filters=256,kernel_size=9,strides=[2,2],padding='VALID')\n",
    "    print('caps shape: ',caps.shape)\n",
    "    u_i = tf.reshape(caps, shape=[-1,32*6*6,8,1])\n",
    "    #caps2 = tf.layers.conv2d(caps1,filters=8,kernel_size=9,strides=[2,2],padding='VALID')\n",
    "    u_i = squash(u_i)\n",
    "    print('u_i shape: ',u_i.shape)\n",
    "    #a_caps1 = squash(caps1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# # First Capsule LAyer:\n",
    "# with tf.name_scope('caps'):\n",
    "#     caps_list = []\n",
    "#     for i in range(8):\n",
    "#         caps_i = tf.layers.conv2d(conv,filters=32,kernel_size=9,strides=[2,2],padding='VALID')\n",
    "#         caps_i = tf.reshape(caps_i, [batch_size,6,6,32,1])\n",
    "#         caps_list.append(caps_i)\n",
    "        \n",
    "#     u_i = tf.concat(caps_list,axis = -1)\n",
    "#     print('caps shape: ',u_i.shape)\n",
    "#     u_i = tf.reshape(u_i, shape=[-1,32*6*6,8,1])\n",
    "#     #caps2 = tf.layers.conv2d(caps1,filters=8,kernel_size=9,strides=[2,2],padding='VALID')\n",
    "#     u_i = squash(u_i)\n",
    "#     print('u_i shape: ',u_i.shape)\n",
    "#     #a_caps1 = squash(caps1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape:  (500, 1152, 10, 8, 16)\n",
      "u_i shape:  (500, 1152, 10, 8, 1)\n",
      "u_hat shape:  (500, 1152, 10, 16, 1)\n",
      "b_IJ shape:  (500, 1152, 10, 1, 1)\n",
      "u_hat shape:  (500, 1152, 10, 16, 1)\n"
     ]
    }
   ],
   "source": [
    "# Routing:\n",
    "\n",
    "with tf.variable_scope('final_layer'):\n",
    "    w_initializer = np.random.normal(size=[1, 1152, 10, 8, 16], scale=0.01)\n",
    "    W = tf.Variable(w_initializer, dtype=tf.float32)\n",
    "\n",
    "    # repeat W with batch_size times to shape [batch_size, 1152, 8, 16]\n",
    "    W = tf.tile(W, [batch_size, 1, 1, 1,1]) # -1 instead of Batch size\n",
    "    print('W shape: ',W.shape)\n",
    "    # calc u_ahat\n",
    "    # [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 16, 1]\n",
    "    u_i = tf.reshape(u_i, shape=(batch_size, -1, 1, u_i.shape[-2].value, 1)) # -1 instead of Batch size\n",
    "    u_i = tf.tile(u_i, [1,1,10,1,1])\n",
    "    print('u_i shape: ',u_i.shape)\n",
    "    u_hat = tf.matmul(W, u_i, transpose_a=True)\n",
    "    print('u_hat shape: ',u_hat.shape)\n",
    "    \n",
    "    with tf.variable_scope('routing'):\n",
    "        # Initialize constants:\n",
    "        b_IJ = tf.zeros([batch_size, u_i.shape[1].value, 10, 1, 1], dtype=np.float32)\n",
    "        print('b_IJ shape: ',b_IJ.shape)\n",
    "        v,b_IJ = routing(u_hat,b_IJ,num_iter)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#def calculate_margin_loss(y_batch,v,n_outputs):\n",
    "m_pos = 0.9\n",
    "m_neg = 0.1\n",
    "margin_loss = 0\n",
    "lambda_const = 0.5\n",
    "\n",
    "with tf.name_scope('margin_loss'):\n",
    "    v =tf.squeeze(v)\n",
    "    #v_norm = tf.map_fn(lambda x: tf.norm(x,axis=1), v)\n",
    "    v_norm = tf.sqrt(tf.reduce_sum(tf.square(v),axis=2))\n",
    "    #v_softmax = tf.nn.softmax(v_norm, axis =1)\n",
    "    #y_pred = tf.argmax(v_norm,axis=1)\n",
    "    #t_k = tf.equal(y_pred,y_train)\n",
    "    t_k = tf.one_hot(y,n_outputs)\n",
    "\n",
    "    # Loss:\n",
    "    max_l = tf.maximum(tf.cast(0,tf.float32),tf.square(m_pos - v_norm))\n",
    "    min_l = tf.maximum(tf.cast(0,tf.float32),tf.square(v_norm - m_neg))\n",
    "\n",
    "    margin_loss = tf.multiply(t_k,tf.square(max_l)) + lambda_const*tf.multiply((1-t_k),tf.square(min_l))\n",
    "#    return margin_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Regularizer Decoder:\n",
    "\n",
    "with tf.name_scope('mask'):\n",
    "    # Masking:\n",
    "    v_list = []\n",
    "    for i,j in zip(range(batch_size),y_train):   # y_train ?\n",
    "        v_list.append(tf.reshape(tf.squeeze(v)[i][j,:],[1,16]))\n",
    "    v_masked = tf.concat(v_list,axis=0) \n",
    "    \n",
    "with tf.name_scope('decoder'):\n",
    "    # 2 FC Relu:\n",
    "    dec1 = tf.layers.dense(inputs=v_masked, units=512, activation=tf.nn.relu)\n",
    "    dec2 = tf.layers.dense(inputs=dec1, units=512, activation=tf.nn.relu)\n",
    "    # 1 FC Sigmoid:\n",
    "    dec3 =  tf.layers.dense(inputs=dec2, units=n_inputs, activation=tf.nn.sigmoid)\n",
    "    loss_reg = tf.norm(tf.reshape(X, [batch_size,n_inputs])-dec3)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Total loss:\n",
    "with tf.name_scope('Loss'):\n",
    "    #margin_loss = calculate_margin_loss(y_batch,v,n_outputs)\n",
    "    loss = tf.reduce_sum(margin_loss) + loss_reg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Backpropagation:\n",
    "with tf.name_scope('optimizer'):\n",
    "    optimizer = tf.train.AdamOptimizer()\n",
    "    training_op = optimizer.minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation:\n",
    "#with tf.name_scope('evaluate'):\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'n_batches' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-17-baeae284b79c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m         \u001b[0mshuffled_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermutation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m         \u001b[0mX_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mshuffled_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m         \u001b[0my_batches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mshuffled_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'n_batches' is not defined"
     ]
    }
   ],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "epochs = 5\n",
    "train_size = X_train.shape[0]\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        shuffled_idx = np.random.permutation(train_size)\n",
    "        X_batches = np.array_split(X_train[shuffled_idx], n_batches)\n",
    "        y_batches = np.array_split(y_train[shuffled_idx], n_batches)\n",
    "\n",
    "        for X_batch,y_batch in zip(X_batches, y_batches):\n",
    "            sess.run(training_op, feed_dict = {X:X_batch, y:y_batch})\n",
    "        print('Epoch {} completed'.format(epoch))\n",
    "        \n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "\n",
    "# Create example y_hat.\n",
    "#sess.run(tf.global_variables_initializer())\n",
    "#y_pred_run = sess.run(y_pred, feed_dict ={X:X_train})\n",
    "#print(sess.run(t_k[0],feed_dict = {X:X_train, y : y_train}))\n",
    "sess.run(tf.global_variables_initializer())\n",
    "#y_sess = sess.run(y_pred, feed_dict={X:X_train, y:y_train})\n",
    "tk_sess = sess.run(t_k, feed_dict={X:X_train, y:y_train})\n",
    "\n",
    "\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "loss = 0\n",
    "for i in range(batch_size):\n",
    "    for k in range(n_outputs):\n",
    "        if y_pred[i]\n",
    "        loss += \n",
    "    \n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Y = tf.one_hot(y_train, depth=10, axis=1, dtype=tf.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "masked_v = tf.multiply(tf.squeeze(v), tf.reshape(Y, (-1, 10, 1)))\n",
    "v_length = tf.sqrt(tf.reduce_sum(tf.square(v), axis=2, keep_dims=True) + 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "masked_v.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x =tf.boolean_mask(tf.reshape(v,[batch_size,10,16]),pd.get_dummies(y_train).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "\n",
    "# Create example y_hat.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "Y_run = sess.run(Y)\n",
    "\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lower_triangular_ones = tf.constant(np.tril(np.ones([30,30])),dtype=tf.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "L.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "shape = [2, 2, 2, 10] \n",
    "L = np.arange(np.prod(shape))\n",
    "L = np.reshape(L, shape)\n",
    "\n",
    "indices = [0, 2, 3, 8]\n",
    "axis = -1 # last dimension\n",
    "\n",
    "def gather_axis(params, indices, axis=0):\n",
    "    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)\n",
    "\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    partL = sess.run(gather_axis(L, indices, axis))\n",
    "    print(partL.shape)\n",
    "    #test = sess.run(tf.gather(tf.unstack(L, axis=axis),indices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "v_test = tf.reshape(tf.range(5000*10*16),[5000,10,16])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "sess = tf.InteractiveSession()\n",
    "v_run = sess.run(v_test)\n",
    "v_mask = sess.run(tf.boolean_mask(v_test,mask=y_train,axis=0))\n",
    "#v_masked = sess.run(gather_axis(v_test,y_train,0))\n",
    "print(v_mask.shape)\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "list_v =[]\n",
    "for i in range(v_run.shape[0]):\n",
    "    list_v.append(tf.reshape(v_test[i,y_train[i],:],[1,16]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "list_v[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def routing(u_hat, b):\n",
    "    v = []\n",
    "    #b_shape = b.get_shape().as_list()\n",
    "    u_hat_stopped = tf.stop_gradient(u_hat, name='u_hat_stopped')\n",
    "    for j in range(10):   \n",
    "        size_splits = [j, 1, b_shape[2] - j - 1]\n",
    "        \n",
    "        for r_iter in range(MAX_ITER):\n",
    "            c = tf.nn.softmax(b,axis=2)\n",
    "\n",
    "            assert c.get_shape() == [1, 1152, 10, 1]\n",
    "            \n",
    "            c = tf.tile(c,[5000,1,1,1])\n",
    "#             b_il, b_ij, b_ir = tf.split(b, size_splits, axis=2)\n",
    "#             c_il, c_ij, c_ir = tf.split(c, size_splits, axis=2)\n",
    "        \n",
    "            # Calculating c_ij\n",
    "            #c_ij = c[:,:,j,:]\n",
    "            #new_shape_c_ij = c_ij.get_shape().as_list() + [1]\n",
    "            #c_ij = tf.reshape(c_ij, new_shape_c_ij)\n",
    "            \n",
    "            #assert c_ij.get_shape() == [1, 1152, 1, 1]\n",
    "            \n",
    "            # Computing s_j\n",
    "            if j == MAX_ITER-1:\n",
    "                s_j = tf.reduce_sum(tf.multiply(c,u_hat))\n",
    "                \n",
    "                \n",
    "                \n",
    "                \n",
    "                s_j = tf.reduce_sum(tf.multiply(c_ij, u_hat),axis=1, keepdims=True)\n",
    "            assert s_j.get_shape() == [5000, 1, 16, 1]\n",
    "            \n",
    "            # computing v_j\n",
    "            v_j = squash(s_j)\n",
    "            # updating b_ij:\n",
    "            \n",
    "            v_j_tiled = tf.tile(v_j, [1,1152,1,1])\n",
    "            a_ij = tf.reduce_sum(tf.matmul(u_hat,v_j_tiled, transpose_a=True),axis = 0, keepdims=True)\n",
    "            #print(v_j.shape)\n",
    "            #print(a_ij.shape)\n",
    "            \n",
    "            b_ij += a_ij\n",
    "            #print('b_ij shape: ',b_ij.shape)\n",
    "            b = tf.concat([b_il, b_ij, b_ir] ,axis =2)\n",
    "        v.append(v_j)\n",
    "    v = tf.concat(v,axis=1)\n",
    "    \n",
    "    print('c_ij shape: ',c_ij.shape)\n",
    "    print('s_j shape: ',s_j.shape)\n",
    "    print('v_j shape: ',v_j.shape)\n",
    "    print('v_j_tiled shape: ',v_j_tiled.shape)\n",
    "    print('a_ij shape: ',a_ij.shape)\n",
    "    print('v shape: ',v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "a_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with tf.name_scope('digit_caps'):\n",
    "    W = tf.Variable( tf.random_uniform((8,16),seed = 1), name=\"theta\",dtype=tf.float32)\n",
    "    u_ji = tf.reshape(tf.matmul(tf.reshape(u_i,[5000*1152,8]),W),[5000,1152,16])\n",
    "    v_j = routing(u_ji, num_iter)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Testing:\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "# We can just use 'c.eval()' without passing 'sess'\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "#print(a.eval(feed_dict=X_test))\n",
    "print(c.eval())\n",
    "sess.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Testing class Caps Layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CapsLayer(object):\n",
    "    def __init__(self, input_tensor, num_capsules, digit_layer = False):\n",
    "        self.input = input_tensor\n",
    "        self.num_capsules = num_capsules\n",
    "        \n",
    "    def __call__(self, num_filters, kernel_size, strides):\n",
    "        # In case it is a primary layer:\n",
    "        if digit_layer == False:\n",
    "            conv = tf.layers.conv2d(inputs=self.input, filters=num_filters, \\\n",
    "                                    kernel_size=kernel_size, activation=tf.nn.relu, \\\n",
    "                                    strides=strides, padding = 'SAME' )\n",
    "            caps = self.squash(conv)\n",
    "            return caps\n",
    "        # In case it is a final layer:\n",
    "        else:\n",
    "            fcaps = self.routing(self.input)\n",
    "            return fcaps\n",
    "            \n",
    "        \n",
    "    def routing(self,capsule):\n",
    "    \n",
    "        \n",
    "    \n",
    "    def create_list_capsules(num_capsules):\n",
    "        for i in range(num_capsules)\n",
    "            if i == 0:\n",
    "                self.list_capsules.append(tf.layers.conv2d(input_tensor, \\\n",
    "                                                           filters=num_filters, \\\n",
    "                                                           kernel_size=kernel_size, \\\n",
    "                                                           strides=[2,2], \\\n",
    "                                                           padding='VALID'))\n",
    "            else:\n",
    "                self.list_capsules.append(create_capsule)\n",
    "        \n",
    "    def create_capsule(input_tensor, kernel_size, strides):\n",
    "        return tf.layers.conv2d(input_capsule, filters=num_filters, kernel_size=kernel_size, strides=[2,2], padding='VALID')\n",
    "    \n",
    "    def squashing(vector):\n",
    "        return tf.norm(vector)/(tf.norm(vector)+1)*(tf.norm(vector)/tf.norm(vector))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
